import os
import sys
import tempfile
import shutil
import pybedtools
import pysam
from Bio.Seq import reverse_complement

# Command-line arguments.
# dataset: compared against "MiSeq" further down to choose between the
#          paired-record and single-record code paths.
# library: BAM file name without the ".bam" suffix; the file is read and
#          then replaced in place by this script.
dataset = sys.argv[1]
library = sys.argv[2]

# Genome assembly name (not referenced by the rest of the visible code).
assembly = "hg38"

# Values of the XT tag for which an alignment is still considered
# unannotated and therefore eligible for an XA annotation.
unannotated = ('mRNA', 'lncRNA', 'gencode', 'fantomcat', 'genome')

def read_precursors(category, distance=500,
                    directory="/osc-fs_home/mdehoon/Data/CASPARs/Filters"):
    """Yield the intervals of one annotation category, widened on both sides.

    Reads ``<directory>/<category>.bed`` and, for every record, yields a
    new six-column interval in which

    * the start is decreased by ``distance`` (but never below 0),
    * the end is increased by ``distance`` (it is not clipped to the
      chromosome length),
    * the name is replaced by ``"pre" + category``.

    Any columns beyond the sixth are dropped.

    Arguments:
     - category:  annotation category; also the BED file's base name.
     - distance:  number of bases by which each side is extended.
     - directory: directory containing the per-category BED files.
    """
    path = os.path.join(directory, "%s.bed" % category)
    name = "pre%s" % category
    for line in pybedtools.BedTool(path):
        line.start = max(0, line.start - distance)
        line.end += distance
        line.name = name
        # Rebuild the interval from its first six fields so that the
        # result always has exactly six columns.
        yield pybedtools.create_interval_from_list(line.fields[:6])

def parse_bamfile(library):
    """Yield a six-column interval for each unannotated alignment in a BAM file.

    Opens ``<library>.bam`` and skips every record that

    * is unmapped,
    * has an XT tag value that is not listed in ``unannotated``, or
    * already carries an XA tag.

    For each remaining record an interval is yielded with the fields
    (reference name, start, end, read name, index, strand).  The index,
    stored as a string in the score column, counts the yielded records
    of the same read name consecutively, starting at 0 whenever the read
    name changes.

    A mapped record without an XT tag raises KeyError.
    """
    filename = "%s.bam" % library
    print("Reading", filename)
    current = None
    number = 0
    # The context manager closes the BAM handle once the generator is
    # exhausted or discarded, instead of leaving that to the garbage
    # collector.
    with pysam.Samfile(filename) as lines:
        for line in lines:
            if line.is_unmapped:
                continue
            if line.get_tag("XT") not in unannotated:
                continue
            if line.has_tag("XA"):
                # already annotated
                continue
            name = line.query_name
            if name != current:
                current = name
                number = 0
            else:
                number += 1
            strand = "-" if line.is_reverse else "+"
            fields = [line.reference_name,
                      line.reference_start,
                      line.reference_end,
                      name,
                      str(number),
                      strand]
            yield pybedtools.create_interval_from_list(fields)

def parse_miseq_bamfile(library):
    """Yield a six-column interval for each unannotated record pair in a BAM file.

    Opens ``<library>.bam`` and consumes its records two at a time,
    treating each consecutive pair as belonging together.  Pairs whose
    first record is unmapped are skipped (the second record must then be
    unmapped as well).  The two records of a mapped pair must lie on
    opposite strands; the interval runs from the start of the forward
    record to the end of the reverse record.

    A pair is then skipped if the XT tag of its first record is not
    listed in ``unannotated``, or if its first record already carries an
    XA tag.

    For each remaining pair an interval is yielded with the fields
    (reference name, start, end, read name, index, strand), where the
    reference name, read name and strand are taken from the first record.
    The index, stored as a string in the score column, counts the yielded
    pairs of the same read name consecutively, starting at 0 whenever the
    read name changes.

    Raises RuntimeError if the file ends in the middle of a pair, and
    AssertionError if a pair violates the expectations described above.
    """
    filename = "%s.bam" % library
    print("Reading", filename)
    current = None
    number = 0
    # The context manager closes the BAM handle once the generator is
    # exhausted or discarded, instead of leaving that to the garbage
    # collector.
    with pysam.Samfile(filename) as lines:
        for line1 in lines:
            # A bare next() would raise StopIteration inside this generator,
            # which Python reports as an uninformative RuntimeError; report
            # the truncated pair explicitly instead.
            line2 = next(lines, None)
            if line2 is None:
                raise RuntimeError("%s: odd number of records; no second "
                                   "record found for read %s"
                                   % (filename, line1.query_name))
            if line1.is_unmapped:
                assert line2.is_unmapped
                continue
            start1 = line1.reference_start
            end1 = line1.reference_end
            start2 = line2.reference_start
            end2 = line2.reference_end
            assert start1 < end1
            assert start2 < end2
            if line1.is_reverse:
                assert not line2.is_reverse
                start = start2
                end = end1
            else:
                if not line2.is_reverse:
                    # show the offending pair before the assertion fails
                    print(line1)
                    print(line2)
                assert line2.is_reverse
                start = start1
                end = end2
            if line1.get_tag("XT") not in unannotated:
                continue
            if line1.has_tag("XA"):
                # already annotated
                continue
            name = line1.query_name
            if name != current:
                current = name
                number = 0
            else:
                number += 1
            strand = "-" if line1.is_reverse else "+"
            fields = [line1.reference_name,
                      start,
                      end,
                      name,
                      str(number),
                      strand]
            yield pybedtools.create_interval_from_list(fields)


# Chromosome sizes file handed to the BedTool sort calls below (g=...), so
# that alignments and annotations are sorted against the same genome file.
chrom_sizes_path = "/osc-fs_home/scratch/mdehoon/Data/Genomes/hg38/hg38.chrom.sizes"

# Annotation categories, processed one after the other in this order; each
# name selects the BED file loaded by read_precursors().
categories = ("snRNA",
              "tRNA",
              "snoRNA",
              "scaRNA",
             )

# For each category in turn:
#   1. collect the intervals of the still-unannotated alignments;
#   2. intersect them, on the same strand, with the extended annotation
#      intervals of the category;
#   3. rewrite the BAM file, adding an XA tag to the overlapping alignments
#      and dropping the other unannotated alignments of the same reads.
# The BAM file is replaced at the end of every iteration, so each category
# operates on the output of the previous one.
for category in categories:
    if dataset == "MiSeq":
        alignments = parse_miseq_bamfile(library)
    else:
        alignments = parse_bamfile(library)
    alignments = pybedtools.BedTool(alignments)
    alignments = alignments.sort(g=chrom_sizes_path)
    # associations[read name][alignment index] = annotation name
    associations = {}
    annotations = read_precursors(category)
    annotations = pybedtools.BedTool(annotations)
    annotations = annotations.sort(g=chrom_sizes_path)
    # require same strand
    overlap = alignments.intersect(annotations, wb=True, s=True)
    for line in overlap:
        # Each output line is expected to hold the six alignment columns
        # followed by the six annotation columns.
        fields = line.fields
        assert len(fields) == 12
        alignment = pybedtools.create_interval_from_list(fields[:6])
        annotation = pybedtools.create_interval_from_list(fields[6:])
        name = alignment.name
        # The score column carries the per-read alignment index that the
        # BAM parser stored there.
        number = int(alignment.score)
        if name not in associations:
            associations[name] = {}
        associations[name][number] = annotation.name
    # The count reported here is the number of read names with at least one
    # association, not the number of alignments.
    print("%s: found %d new annotations" % (category, len(associations)))
    filename = "%s.bam" % library
    print("Reading %s" % filename)
    alignments = pysam.Samfile(filename)
    current = ""
    # Only the name of the temporary file is needed; it is reopened for
    # writing by pysam and moved over the original BAM file at the end.
    stream = tempfile.NamedTemporaryFile(delete=False)
    stream.close()
    print("Writing %s" % stream.name)
    output = pysam.Samfile(stream.name, "wb", template=alignments)
    # number of alignments (pairs for MiSeq) left out of the output
    skip = 0
    # State carried across iterations for the current read name:
    #   number       - index of the current record (pair for MiSeq)
    #   sequence     - sequence taken from the first record of the read
    #   is_secondary - False until a record of the read has been written
    # NOTE(review): number counts every record of a read here, whereas the
    # BAM parsers count only the records that pass their filters; the two
    # indices agree only if all records of a read pass or fail those filters
    # together - confirm that the input guarantees this.
    for alignment in alignments:
        name = alignment.query_name
        if name == current:
            number += 1
        else:
            current = name
            number = 0
            # Remember the sequence of the first record of this read,
            # reverse-complemented if that record is on the reverse strand.
            sequence = alignment.query_sequence
            is_secondary = False
            if alignment.is_reverse:
                try:
                    sequence = reverse_complement(sequence)
                except Exception:
                    # report which read failed before re-raising
                    print(name)
                    raise
        if not alignment.is_unmapped:
            target = alignment.get_tag("XT")
            if target in unannotated:
                current_associations = associations.get(name)
                if current_associations is not None:
                    annotation = current_associations.get(number)
                    if annotation is None:
                        # other mapping locations of this read are annotated,
                        # but the current mapping location is not annotated
                        skip += 1
                        if dataset == "MiSeq":
                            # also discard the second record of the pair
                            alignment = next(alignments)
                            assert alignment.query_name == current
                        # is_secondary is left unchanged, so if nothing has
                        # been written for this read yet, the next record
                        # written is the one flagged as not secondary
                        continue
                    try:
                        alignment.get_tag("XA")
                    except KeyError:
                        pass
                    else:
                        raise Exception("found existing annotation tag XA")
                    alignment.set_tag("XA", annotation)
        if not is_secondary:
            # First record written for this read: assign the remembered
            # sequence, reverse-complemented again if this record is on the
            # reverse strand.
            # NOTE(review): pysam documents that assigning query_sequence
            # invalidates the stored base qualities - confirm that losing
            # them here is acceptable.
            if alignment.is_reverse:
                sequence = reverse_complement(sequence)
            alignment.query_sequence = sequence
        alignment.is_secondary = is_secondary
        output.write(alignment)
        if dataset == "MiSeq":
            # The second record of the pair follows immediately; it gets the
            # same secondary flag but its sequence is left untouched.
            alignment = next(alignments)
            assert alignment.query_name == current
            alignment.is_secondary = is_secondary
            output.write(alignment)
        # every further record written for this read is flagged secondary
        is_secondary = True
    output.close()
    alignments.close()
    print("Number of removed lines: %d" % skip)
    # Replace the input BAM file with the rewritten one.
    print("Moving %s to %s" % (stream.name, filename))
    shutil.move(stream.name, filename)
